import datetime
import pickle

import torch


def get_dict_start_time_index():
    """Map every ``MM_DD_HH`` key to a 1-based sequential index.

    Keys are enumerated month-major, then day, then hour, each zero-padded
    to two digits (e.g. ``'01_01_00'`` -> 1, ``'01_01_01'`` -> 2).

    NOTE(review): every month gets days 01..31 and hours run 00..24 inclusive
    (25 slots), so some keys never match a real timestamp. This is kept
    unchanged because the index values may be persisted or used as
    embedding ids elsewhere — confirm before altering.
    """
    keys = (
        f'{month:02d}_{day:02d}_{hour:02d}'
        for month in range(1, 13)
        for day in range(1, 32)
        for hour in range(0, 25)
    )
    return {key: index for index, key in enumerate(keys, start=1)}


def write(path, target, mode):
    """Write the string *target* followed by a newline to *path*.

    Args:
        path: file to open.
        target: string to write (concatenated with ``'\\n'``).
        mode: ``open`` mode, e.g. ``'w'`` to overwrite or ``'a'`` to append.
    """
    with open(path, mode) as handle:
        line = target + '\n'
        handle.write(line)


def transfer_torch(target, device=None):
    """Wrap *target* in a ``torch.Tensor`` and move it to a device.

    Args:
        target: data accepted by the ``torch.Tensor`` constructor (e.g. a list).
        device: destination device. When omitted, a module-level ``args``
            object's ``device`` attribute is used if one has been injected
            into this module; otherwise ``'cpu'``.

    Returns:
        The tensor on the chosen device.
    """
    if device is None:
        # This module never defines ``args``; the original therefore raised
        # NameError unless a caller injected it (presumably ``module.args = args``).
        # Honour such an injection when present, else fall back to CPU.
        device = getattr(globals().get('args'), 'device', 'cpu')
    return torch.Tensor(target).to(device)


def convert_date(target):
    """Parse an ISO-8601-style timestamp into a naive ``datetime``.

    Accepted form: ``YYYY-MM-DDTHH:MM:SS[.fraction][Z|+HH:MM|-HH:MM]``.
    Any ``Z`` designator or UTC offset is stripped and *not* applied, so the
    wall-clock value in the string is returned as-is (no tzinfo).

    The digits after ``'.'`` are treated as a decimal fraction of a second
    and converted to microseconds (``.5`` -> 500000, ``.123`` -> 123000).
    Digits beyond the sixth are truncated. A missing fraction means 0.

    Raises:
        ValueError: if the string does not match the expected layout.
    """
    date_part, time_part = target.split('T', 1)
    # Drop any zone suffix. '-' is only searched for in the time part because
    # the date part uses '-' as its own separator.
    for zone_sep in ('Z', '+', '-'):
        time_part = time_part.split(zone_sep, 1)[0]

    year, month, day = map(int, date_part.split('-'))
    clock, _, fraction = time_part.partition('.')
    hour, minute, second = map(int, clock.split(':'))
    # Scale the fractional digits to microseconds. The original passed them
    # through verbatim, so '.123' became 123 µs instead of 123000 µs.
    microsecond = int(fraction[:6].ljust(6, '0')) if fraction else 0

    return datetime.datetime(year, month, day, hour, minute, second,
                             microsecond)


def read_pickle(target_path):
    """Unpickle and return the object stored in the file at *target_path*.

    NOTE: ``pickle.load`` can execute arbitrary code during loading; only
    call this on files from a trusted source.
    """
    with open(target_path, 'rb') as handle:
        return pickle.load(handle)


def read_mapping_item_id(target_path):
    """Read an ``old,new`` integer id mapping file into a dict.

    Each non-blank line must start with two comma-separated integers: the
    original id and its remapped id. Extra columns are ignored. If an old id
    repeats, the later line wins. Blank lines (e.g. a trailing empty line)
    are skipped instead of crashing on ``int('')``.

    Returns:
        dict mapping ``int`` old id -> ``int`` new id.
    """
    old2new = {}
    with open(target_path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            fields = line.split(',')
            old2new[int(fields[0])] = int(fields[1])
    return old2new


def get_parameter(*models):
    """Return one flat list holding every model's ``parameters()``, in argument order."""
    return [param for model in models for param in model.parameters()]


def convert_list_to_dict(target_list, start_idx):
    """Assign consecutive ids, starting at *start_idx*, to the items of *target_list*.

    Returns:
        ``(old_to_new, new_to_old)``. ``old_to_new`` maps item -> id and
        ``new_to_old`` maps id -> item. If an item repeats, ``old_to_new``
        keeps its last id, while ``new_to_old`` still has one entry per position.
    """
    forward = {}
    backward = {}
    next_id = start_idx
    for item in target_list:
        forward[item] = next_id
        backward[next_id] = item
        next_id += 1
    return forward, backward


def get_acc(true_label, predicted_label, threshold=0.5):
    """Return binary-classification accuracy as a 0-d float tensor.

    Args:
        true_label: 0/1 target labels (tensor, or a value comparable with one).
        predicted_label: score tensor. Entries strictly greater than
            *threshold* count as positive (1), all others as 0.
        threshold: decision threshold. Defaults to the previously
            hard-coded 0.5.

    Returns:
        Number of matching elements divided by the number of predictions.
        An empty batch yields nan (0 / 0).
    """
    output2binary = (predicted_label > threshold).float()
    # An (N, 1) prediction compared with an (N,) label would silently
    # broadcast to (N, N) and over-count matches. Align the shapes when both
    # hold the same number of elements.
    if (torch.is_tensor(true_label)
            and true_label.shape != output2binary.shape
            and true_label.numel() == output2binary.numel()):
        true_label = true_label.reshape(output2binary.shape)
    correct = (output2binary == true_label).float().sum()

    # numel() equals size(0) for 1-D / (N, 1) inputs and keeps the result in
    # [0, 1] for multi-column inputs.
    return correct / output2binary.numel()


def load_weight(to_model, from_model, weight_list):
    """Copy tensors from *from_model* into the matching attributes of *to_model*.

    For each name in *weight_list* this performs
    ``<to_model>.<name>.data = from_model[name]``. Dotted names such as
    ``'fc.weight'`` or ``'layers.0.bias'`` are followed attribute by attribute.

    Args:
        to_model: object (typically a model) whose attributes receive data.
        from_model: mapping (e.g. a state dict) indexed by the weight names.
        weight_list: iterable of attribute-path names to copy.

    For backward compatibility, either object argument may also be given as
    the *name* of a global in this module. That was the only form the
    previous ``exec``-based version could execute, because interpolating an
    object's ``str()`` does not produce valid source.
    """
    # Resolve legacy string names explicitly instead of exec'ing built source
    # text, which could run arbitrary code.
    if isinstance(to_model, str):
        to_model = globals()[to_model]
    if isinstance(from_model, str):
        from_model = globals()[from_model]

    for weight in weight_list:
        target = to_model
        for part in weight.split('.'):
            target = getattr(target, part)
        target.data = from_model[weight]


def load_csv(target_path):
    """Return the raw lines of *target_path* as a list, trailing newlines included.

    No CSV parsing is done; each element is one text line.
    """
    with open(target_path, 'r') as handle:
        return list(handle)


if __name__ == '__main__':
    # Smoke demo for convert_list_to_dict: ids are assigned starting from 1,
    # and the (old_to_new, new_to_old) tuple is printed.
    sample_ids = [11, 23, 43, 53]
    mappings = convert_list_to_dict(sample_ids, 1)
    print(mappings)
